/**
 * Helper utilities to copy a 3D tensor into a larger 3D tensor
 * and back. Used in ConcatLayer.
 */
// Per-thread prologue shared by the shifted-copy device functions.
//
// Must be expanded inside a function that has the int parameters dim_x,
// dim_y, dim_z, dim_y2 and shift in scope. The expansion:
//   1. derives this thread's (x, y, z) coordinate from the 3D launch indices;
//   2. returns from the enclosing function when that coordinate lies outside
//      the dim_x * dim_y * dim_z box;
//   3. defines `idx`, the flat offset of (x, y, z) with extents
//      (dim_x, dim_y), and `idx_shift`, the flat offset of (x, y + shift, z)
//      with extents (dim_x, dim_y2).
//
// Both offsets are evaluated in 64-bit arithmetic: the products
// dim_x * dim_y * dim_z and dim_x * dim_y2 * dim_z can exceed INT_MAX for
// large tensors, and a 32-bit int computation would overflow there.
//
// The final line intentionally has no trailing semicolon; the use site
// supplies it.
#define SHIFTCOPY_BOUNDS_AND_INDEX \
  int x = threadIdx.x + blockIdx.x * blockDim.x; \
  int y = threadIdx.y + blockIdx.y * blockDim.y; \
  int z = threadIdx.z + blockIdx.z * blockDim.z; \
  if (x >= dim_x || y >= dim_y || z >= dim_z) \
    return; \
  long long idx = x + (long long)dim_x * (y + (long long)dim_y * z); \
  long long idx_shift = \
      x + (long long)dim_x * ((long long)y + shift + (long long)dim_y2 * z)

// dim_y2 is the y-extent of the bigger tensor that we are copying into (or back out of);
// shift is the y-offset at which the smaller tensor sits inside it. Both tensors share dim_x.
// Copies one element per thread from the compact tensor into the shifted
// position of the other tensor.
//
// A thread whose (x, y, z) launch coordinate lies inside the
// dim_x * dim_y * dim_z box reads src at offset x + dim_x*(y + dim_y*z) and
// writes it to dst at offset x + dim_x*(y + shift + dim_y2*z). Threads
// outside the box return without touching memory.
//
//   dst      buffer written, indexed with y-extent dim_y2
//   src      buffer read, indexed with y-extent dim_y
//   dim_x, dim_y, dim_z   extents of the region being copied
//   dim_y2   y-extent used when indexing dst
//   shift    offset added to y when indexing dst
template <typename T>
__device__ void copy_to_shifted(T *dst, T *src,
                                int dim_x, int dim_y, int dim_z,
                                int dim_y2, int shift) {
  // Declares idx / idx_shift and returns early for out-of-range threads.
  SHIFTCOPY_BOUNDS_AND_INDEX;
  const T element = src[idx];
  dst[idx_shift] = element;
}
// Copies one element per thread from the shifted position of one tensor back
// into the compact tensor; the inverse addressing of copy_to_shifted.
//
// A thread whose (x, y, z) launch coordinate lies inside the
// dim_x * dim_y * dim_z box reads src at offset
// x + dim_x*(y + shift + dim_y2*z) and writes it to dst at offset
// x + dim_x*(y + dim_y*z). Threads outside the box return without touching
// memory.
//
//   dst      buffer written, indexed with y-extent dim_y
//   src      buffer read, indexed with y-extent dim_y2
//   dim_x, dim_y, dim_z   extents of the region being copied
//   dim_y2   y-extent used when indexing src
//   shift    offset added to y when indexing src
template <typename T>
__device__ void copy_from_shifted(T *dst, T *src,
                                  int dim_x, int dim_y, int dim_z,
                                  int dim_y2, int shift) {
  // Declares idx / idx_shift and returns early for out-of-range threads.
  SHIFTCOPY_BOUNDS_AND_INDEX;
  const T element = src[idx_shift];
  dst[idx] = element;
}

// Stamps out a __global__ entry point named <name>_<dtype> that forwards all
// of its arguments, unchanged and in the same order, to the templated
// __device__ function <name>.
#define DEF_SHIFT_COPY(name, dtype)                                   \
  __global__ void name ## _ ## dtype(dtype *dst, dtype *src,          \
                                     int dim_x, int dim_y, int dim_z, \
                                     int dim_y2, int shift) {         \
    name(dst, src, dim_x, dim_y, dim_z, dim_y2, shift);               \
  }

// Kernel entry points with C linkage, generated by DEF_SHIFT_COPY:
//   copy_to_shifted_float,  copy_from_shifted_float,
//   copy_to_shifted_double, copy_from_shifted_double.
extern "C" {
// float kernels
DEF_SHIFT_COPY(copy_to_shifted, float)
DEF_SHIFT_COPY(copy_from_shifted, float)
// double kernels
DEF_SHIFT_COPY(copy_to_shifted, double)
DEF_SHIFT_COPY(copy_from_shifted, double)
}  // extern "C"

// vim: ft=cuda
